{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Image classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "hide_input": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<iframe width=\"560\" height=\"315\" src=\"https://www.youtube.com/embed/tjAIM7BOYhw?rel=0&amp;controls=0&amp;showinfo=0\" frameborder=\"0\" allowfullscreen></iframe>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#@title\n",
    "from IPython.display import HTML\n",
    "\n",
    "HTML('<iframe width=\"560\" height=\"315\" src=\"https://www.youtube.com/embed/tjAIM7BOYhw?rel=0&amp;controls=0&amp;showinfo=0\" frameborder=\"0\" allowfullscreen></iframe>')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "画像分類では、画像にラベルまたはクラスを割り当てます。テキストや音声の分類とは異なり、入力は\n",
    "画像を構成するピクセル値。損傷の検出など、画像分類には多くの用途があります\n",
    "自然災害の後、作物の健康状態を監視したり、病気の兆候がないか医療画像をスクリーニングしたりするのに役立ちます。\n",
    "\n",
    "このガイドでは、次の方法を説明します。\n",
    "\n",
    "1. [Food-101](https://huggingface.co/datasets/ethz/food101) データセットの [ViT](https://huggingface.co/docs/transformers/main/ja/tasks/model_doc/vit) を微調整して、画像内の食品を分類します。\n",
    "2. 微調整したモデルを推論に使用します。\n",
    "\n",
    "<Tip>\n",
    "\n",
    "このタスクと互換性のあるすべてのアーキテクチャとチェックポイントを確認するには、[タスクページ](https://huggingface.co/tasks/image-classification) を確認することをお勧めします。\n",
    "\n",
    "</Tip>\n",
    "\n",
    "始める前に、必要なライブラリがすべてインストールされていることを確認してください。\n",
    "\n",
    "```bash\n",
    "pip install transformers datasets evaluate\n",
    "```\n",
    "\n",
    "Hugging Face アカウントにログインして、モデルをアップロードしてコミュニティと共有することをお勧めします。プロンプトが表示されたら、トークンを入力してログインします。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Food-101 dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Datasets、🤗 データセット ライブラリから Food-101 データセットの小さいサブセットを読み込みます。これにより、次の機会が得られます\n",
    "完全なデータセットのトレーニングにさらに時間を費やす前に、実験してすべてが機能することを確認してください。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "food = load_dataset(\"ethz/food101\", split=\"train[:5000]\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`train_test_split` メソッドを使用して、データセットの `train` 分割をトレイン セットとテスト セットに分割します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "food = food.train_test_split(test_size=0.2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "次に、例を見てみましょう。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512 at 0x7F52AFC8AC50>,\n",
       " 'label': 79}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "food[\"train\"][0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "データセット内の各例には 2 つのフィールドがあります。\n",
    "\n",
    "- `image`: 食品の PIL 画像\n",
    "- `label`: 食品のラベルクラス\n",
    "\n",
    "モデルがラベル ID からラベル名を取得しやすくするために、ラベル名をマップする辞書を作成します。\n",
    "整数への変換、またはその逆:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "labels = food[\"train\"].features[\"label\"].names\n",
    "label2id, id2label = dict(), dict()\n",
    "for i, label in enumerate(labels):\n",
    "    label2id[label] = str(i)\n",
    "    id2label[str(i)] = label"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "これで、ラベル ID をラベル名に変換できるようになりました。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'prime_rib'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "id2label[str(79)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preprocess"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "次のステップでは、ViT 画像プロセッサをロードして画像をテンソルに処理します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoImageProcessor\n",
    "\n",
    "checkpoint = \"google/vit-base-patch16-224-in21k\"\n",
    "image_processor = AutoImageProcessor.from_pretrained(checkpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "いくつかの画像変換を画像に適用して、モデルの過学習に対する堅牢性を高めます。ここでは torchvision の [`transforms`](https://pytorch.org/vision/stable/transforms.html) モジュールを使用しますが、任意の画像ライブラリを使用することもできます。\n",
    "\n",
    "画像のランダムな部分をトリミングし、サイズを変更し、画像の平均と標準偏差で正規化します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor\n",
    "\n",
    "normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)\n",
    "size = (\n",
    "    image_processor.size[\"shortest_edge\"]\n",
    "    if \"shortest_edge\" in image_processor.size\n",
    "    else (image_processor.size[\"height\"], image_processor.size[\"width\"])\n",
    ")\n",
    "_transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "次に、変換を適用し、画像の `pixel_values` (モデルへの入力) を返す前処理関数を作成します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def transforms(examples):\n",
    "    examples[\"pixel_values\"] = [_transforms(img.convert(\"RGB\")) for img in examples[\"image\"]]\n",
    "    del examples[\"image\"]\n",
    "    return examples"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "データセット全体に前処理関数を適用するには、🤗 Datasets `with_transform` メソッドを使用します。変換は、データセットの要素を読み込むときにオンザフライで適用されます。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "food = food.with_transform(transforms)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "次に、`DefaultDataCollat​​or` を使用してサンプルのバッチを作成します。 🤗 Transformers の他のデータ照合器とは異なり、`DefaultDataCollat​​or` はパディングなどの追加の前処理を適用しません。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import DefaultDataCollator\n",
    "\n",
    "data_collator = DefaultDataCollator()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "トレーニング中にメトリクスを含めると、多くの場合、モデルのパフォーマンスを評価するのに役立ちます。すぐにロードできます\n",
    "🤗 [Evaluate](https://huggingface.co/docs/evaluate/index) ライブラリを使用した評価方法。このタスクでは、ロードします\n",
    "[accuracy](https://huggingface.co/spaces/evaluate-metric/accuracy) 指標 (詳細については、🤗 評価 [クイック ツアー](https://huggingface.co/docs/evaluate/a_quick_tour) を参照してくださいメトリクスをロードして計算する方法):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import evaluate\n",
    "\n",
    "accuracy = evaluate.load(\"accuracy\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "次に、予測とラベルを `compute` に渡して精度を計算する関数を作成します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    predictions, labels = eval_pred\n",
    "    predictions = np.argmax(predictions, axis=1)\n",
    "    return accuracy.compute(predictions=predictions, references=labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "これで `compute_metrics`関数の準備が整いました。トレーニングを設定するときにこの関数に戻ります。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<Tip>\n",
    "\n",
    "[Trainer](https://huggingface.co/docs/transformers/main/ja/main_classes/trainer#transformers.Trainer) を使用したモデルの微調整に慣れていない場合は、[こちら](https://huggingface.co/docs/transformers/main/ja/tasks/../training#train-with-pytorch-trainer) の基本的なチュートリアルをご覧ください。\n",
    "\n",
    "</Tip>\n",
    "\n",
    "これでモデルのトレーニングを開始する準備が整いました。 [AutoModelForImageClassification](https://huggingface.co/docs/transformers/main/ja/model_doc/auto#transformers.AutoModelForImageClassification) を使用して ViT をロードします。ラベルの数と予想されるラベルの数、およびラベル マッピングを指定します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForImageClassification, TrainingArguments, Trainer\n",
    "\n",
    "model = AutoModelForImageClassification.from_pretrained(\n",
    "    checkpoint,\n",
    "    num_labels=len(labels),\n",
    "    id2label=id2label,\n",
    "    label2id=label2id,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "この時点で残っているステップは 3 つだけです。\n",
    "\n",
    "1. [TrainingArguments](https://huggingface.co/docs/transformers/main/ja/main_classes/trainer#transformers.TrainingArguments) でトレーニング ハイパーパラメータを定義します。 `image` 列が削除されるため、未使用の列を削除しないことが重要です。 `image` 列がないと、`pixel_values` を作成できません。この動作を防ぐには、`remove_unused_columns=False`を設定してください。他に必要なパラメータは、モデルの保存場所を指定する `output_dir` だけです。 `push_to_hub=True`を設定して、このモデルをハブにプッシュします (モデルをアップロードするには、Hugging Face にサインインする必要があります)。各エポックの終了時に、[Trainer](https://huggingface.co/docs/transformers/main/ja/main_classes/trainer#transformers.Trainer) は精度を評価し、トレーニング チェックポイントを保存します。\n",
    "2. トレーニング引数を、モデル、データセット、トークナイザー、データ照合器、および `compute_metrics` 関数とともに [Trainer](https://huggingface.co/docs/transformers/main/ja/main_classes/trainer#transformers.Trainer) に渡します。\n",
    "3. [train()](https://huggingface.co/docs/transformers/main/ja/main_classes/trainer#transformers.Trainer.train) を呼び出してモデルを微調整します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_args = TrainingArguments(\n",
    "    output_dir=\"my_awesome_food_model\",\n",
    "    remove_unused_columns=False,\n",
    "    eval_strategy=\"epoch\",\n",
    "    save_strategy=\"epoch\",\n",
    "    learning_rate=5e-5,\n",
    "    per_device_train_batch_size=16,\n",
    "    gradient_accumulation_steps=4,\n",
    "    per_device_eval_batch_size=16,\n",
    "    num_train_epochs=3,\n",
    "    warmup_steps=0.1,\n",
    "    logging_steps=10,\n",
    "    load_best_model_at_end=True,\n",
    "    metric_for_best_model=\"accuracy\",\n",
    "    push_to_hub=True,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    data_collator=data_collator,\n",
    "    train_dataset=food[\"train\"],\n",
    "    eval_dataset=food[\"test\"],\n",
    "    processing_class=image_processor,\n",
    "    compute_metrics=compute_metrics,\n",
    ")\n",
    "\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "トレーニングが完了したら、 [push_to_hub()](https://huggingface.co/docs/transformers/main/ja/main_classes/trainer#transformers.Trainer.push_to_hub) メソッドを使用してモデルをハブに共有し、誰もがモデルを使用できるようにします。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.push_to_hub()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<Tip>\n",
    "\n",
    "画像分類用のモデルを微調整する方法の詳細な例については、対応する [PyTorch ノートブック](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/image_classification.ipynb)\n",
    "\n",
    "</Tip>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "モデルを微調整したので、それを推論に使用できるようになりました。\n",
    "\n",
    "推論を実行したい画像を読み込みます。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = load_dataset(\"ethz/food101\", split=\"validation[:10]\")\n",
    "image = ds[\"image\"][0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"flex justify-center\">\n",
    "    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png\" alt=\"image of beignets\"/>\n",
    "</div>\n",
    "\n",
    "推論用に微調整されたモデルを試す最も簡単な方法は、それを [pipeline()](https://huggingface.co/docs/transformers/main/ja/main_classes/pipelines#transformers.pipeline) で使用することです。モデルを使用して画像分類用の`pipeline`をインスタンス化し、それに画像を渡します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'score': 0.31856709718704224, 'label': 'beignets'},\n",
       " {'score': 0.015232225880026817, 'label': 'bruschetta'},\n",
       " {'score': 0.01519392803311348, 'label': 'chicken_wings'},\n",
       " {'score': 0.013022331520915031, 'label': 'pork_chop'},\n",
       " {'score': 0.012728818692266941, 'label': 'prime_rib'}]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "classifier = pipeline(\"image-classification\", model=\"my_awesome_food_model\")\n",
    "classifier(image)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "必要に応じて、`pipeline`の結果を手動で複製することもできます。\n",
    "\n",
    "\n",
    "画像プロセッサをロードして画像を前処理し、`input`を PyTorch テンソルとして返します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoImageProcessor\n",
    "import torch\n",
    "\n",
    "image_processor = AutoImageProcessor.from_pretrained(\"my_awesome_food_model\")\n",
    "inputs = image_processor(image, return_tensors=\"pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "入力をモデルに渡し、ロジットを返します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForImageClassification\n",
    "\n",
    "model = AutoModelForImageClassification.from_pretrained(\"my_awesome_food_model\")\n",
    "with torch.no_grad():\n",
    "    logits = model(**inputs).logits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "最も高い確率で予測されたラベルを取得し、モデルの `id2label` マッピングを使用してラベルに変換します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'beignets'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predicted_label = logits.argmax(-1).item()\n",
    "model.config.id2label[predicted_label]"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 4
}
